package com.hoppinzq;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject;
import com.hoppinzq.bean.*;
import com.hoppinzq.embedBean.Embed;
import com.hoppinzq.embedBean.EmbeddingData;
import org.apache.commons.io.FileUtils;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;

import java.io.*;
import java.util.*;

public class demo2 {

    /** Redis list that caches every video (as JSON) together with its embedding vector. */
    private static final String VIDEOS_KEY = "demo:videos";

    /** Classpath resource holding the raw video catalogue as a JSON array. */
    private static final String VIDEO_RESOURCE = "video.json";

    /**
     * Demo entry point: builds a sample user and viewing action, makes sure the
     * video cache exists in Redis, then prints videos recommended for that user.
     *
     * @param args unused
     * @throws IOException declared by {@link #initVideo()}
     */
    public static void main(String[] args) throws IOException {
        // A user.
        User user = new User("id_1", "zhangqi");
        // At sign-up the user said they prefer the categories "学习" and "东方".
        List<String> userPrefer = new ArrayList<String>();
        userPrefer.add("学习");
        userPrefer.add("东方");
        // The user portrait records the same preferences.
        UserPortrait userPortrait = new UserPortrait(user, userPrefer);
        // The user watched the video "苇名弦一郎" (category "只狼") for two hours.
        Action action = new Action("action_id_1", "浏览", "2023-05-05 11:00:00", "2023-05-05 13:00:00", "2hour");
        Video video = new Video("116", "苇名弦一郎", "苇名，就是我的一切……为此，我愿……放弃为人", "只狼");
        UserVideoAction userVideoAction = new UserVideoAction(user, action, video, "2023-05-05 13:00:00");
        // Setup summary: user zhangqi, who prefers "学习" and "东方", watched the
        // "只狼" video "苇名弦一郎" for two hours.
        // Possible extension: if the watch time exceeds 5 hours, or the user has watched
        // 5 videos of category "只狼", add "只狼" to the portrait's preferences.

        initVideo();
        // Query text = watched video's name + description + the user's stated preferences.
        getRecommendVideo(video.getVideo_name() + video.getVideo_miaoshu() + "学习,东方", 10);
    }

    /**
     * Loads {@code video.json} from the classpath, embeds every video and caches the
     * result in the Redis list {@code demo:videos}. Does nothing if that list already exists.
     * All failures are printed and swallowed so the demo can continue.
     *
     * @throws IOException declared for callers; failures are currently caught and printed here
     */
    public static void initVideo() throws IOException {
        RedisDemo redisDemo = new RedisDemo();
        JedisPool jedisPool = redisDemo.getJedisPool();

        try (Jedis jedis = jedisPool.getResource()) {
            if (jedis.exists(VIDEOS_KEY)) {
                return;
            }

            String text = readClasspathResource(VIDEO_RESOURCE);
            JSONArray videos = JSON.parseArray(text);
            if (videos == null || videos.isEmpty()) {
                // Nothing to cache; an RPUSH with no values would be rejected by Redis.
                return;
            }

            // One text per video: name + description + category name + category description.
            List<String> video2Embeddings = new ArrayList<>(videos.size());
            for (int i = 0; i < videos.size(); i++) {
                JSONObject video = videos.getJSONObject(i);
                video2Embeddings.add(video.get("video_name").toString() + video.get("video_miaoshu")
                        + video.get("video_fenlei_name") + video.get("video_fenlei_miaoshu"));
            }
            Embed embedAll = Embedding.getEmbed(video2Embeddings);
            EmbeddingData[] embeddings = embedAll.getData();
            // Embeddings are matched to videos by index, so the counts must agree.
            if (embeddings == null || embeddings.length != videos.size()) {
                throw new IllegalStateException("expected " + videos.size() + " embeddings but got "
                        + (embeddings == null ? 0 : embeddings.length));
            }

            String[] payloads = new String[videos.size()];
            for (int i = 0; i < videos.size(); i++) {
                JSONObject video = videos.getJSONObject(i);
                video.put("embedding", embeddings[i].getEmbedding());
                payloads[i] = video.toJSONString();
            }
            // Push everything in one RPUSH: a failure part-way through a per-item loop would
            // leave a partial list that the exists() check above treats as a complete cache.
            jedis.rpush(VIDEOS_KEY, payloads);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            redisDemo.close();
        }
    }

    /**
     * Reads a classpath resource as UTF-8 text. Streaming the resource (instead of
     * converting its URL into a File path) also works for resources packed in a jar
     * and for paths containing URL-encoded characters other than spaces.
     *
     * @param name resource name relative to the context class loader
     * @return the resource content
     * @throws FileNotFoundException if the resource does not exist
     * @throws IOException if reading fails
     */
    private static String readClasspathResource(String name) throws IOException {
        try (InputStream in = Thread.currentThread().getContextClassLoader().getResourceAsStream(name)) {
            if (in == null) {
                throw new FileNotFoundException("classpath resource not found: " + name);
            }
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buffer = new byte[8192];
            int read;
            while ((read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
            return out.toString("UTF-8");
        }
    }

    /**
     * Prints the videos most similar to {@code question}. This simulates work a big-data
     * or cloud platform would do. The query embedding is compared against every cached
     * video embedding by brute force. With millions of videos that means millions of
     * cosine similarities, which could be sped up with vectorized or parallel computation
     * (or simply several threads).
     *
     * @param question free text to embed and compare against the cached videos
     * @param number   maximum number of videos to print; clamped to the number available
     */
    private static void getRecommendVideo(String question, int number) {

        RedisDemo redisDemo = new RedisDemo();
        JedisPool jedisPool = redisDemo.getJedisPool();

        try (Jedis jedis = jedisPool.getResource()) {
            if (!jedis.exists(VIDEOS_KEY)) {
                return;
            }
            Embed embedQuestions = Embedding.getEmbed(question);
            double[] questionEmbedding = embedQuestions.getData()[0].getEmbedding();
            List<String> videos = jedis.lrange(VIDEOS_KEY, 0, -1);
            List<JSONObject> videoSimilarity = new ArrayList<>(videos.size());
            for (String raw : videos) {
                JSONObject video = JSONObject.parseObject(raw);
                JSONArray videoEmbedding = video.getJSONArray("embedding");
                if (videoEmbedding == null) {
                    // Skip a malformed cache entry instead of aborting the whole recommendation.
                    continue;
                }
                JSONObject scored = new JSONObject();
                // The "embedding" key holds the similarity score (key kept for output compatibility).
                scored.put("embedding", cosineSimilarity(questionEmbedding, getDoubleArray(videoEmbedding)));
                scored.put("video_name", video.get("video_name"));
                scored.put("video_miaoshu", video.get("video_miaoshu"));
                videoSimilarity.add(scored);
            }

            // Sort by similarity, highest first. Double.compare defines a consistent
            // ordering even for NaN, so Collections.sort cannot fail on an inconsistent comparator.
            Collections.sort(videoSimilarity, new Comparator<JSONObject>() {
                @Override
                public int compare(JSONObject o1, JSONObject o2) {
                    return Double.compare(o2.getDouble("embedding"), o1.getDouble("embedding"));
                }
            });

            // Never ask subList for more (or fewer than zero) elements than exist.
            int count = Math.max(0, Math.min(number, videoSimilarity.size()));
            System.err.println("下面是推荐的" + count + "个视频：");
            System.err.println(formatJson(JSONArray.toJSONString(videoSimilarity.subList(0, count))));
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            redisDemo.close();
        }
    }

    /**
     * Computes the cosine similarity of two vectors.
     *
     * @param vector1 first vector
     * @param vector2 second vector; must have the same length as {@code vector1}
     * @return dot(v1, v2) / (|v1| * |v2|), or 0.0 when either vector has zero norm
     *         (the ratio would otherwise be NaN)
     * @throws IllegalArgumentException (commons-math DimensionMismatchException) on length mismatch
     */
    private static double cosineSimilarity(double[] vector1, double[] vector2) {
        RealVector realVector1 = new ArrayRealVector(vector1);
        RealVector realVector2 = new ArrayRealVector(vector2);
        // Compute the dot product first so a dimension mismatch is still reported.
        double dotProduct = realVector1.dotProduct(realVector2);
        double norm = realVector1.getNorm() * realVector2.getNorm();
        if (norm == 0.0) {
            return 0.0;
        }
        return dotProduct / norm;
    }

    /**
     * Converts a JSON array of numbers into a primitive double array.
     * Elements that are {@code null} or cannot be converted become 0.0.
     *
     * @param jsonArray source array
     * @return array of the same length
     */
    private static double[] getDoubleArray(JSONArray jsonArray) {
        double[] array = new double[jsonArray.size()];
        for (int i = 0; i < jsonArray.size(); i++) {
            try {
                Double value = jsonArray.getDouble(i);
                // getDouble returns null for a JSON null; unboxing it directly would throw NPE.
                array[i] = value == null ? 0.0 : value;
            } catch (JSONException e) {
                e.printStackTrace();
                array[i] = 0.0;
            }
        }
        return array;
    }

    /**
     * Returns a new array holding at most the first {@code num} objects of {@code jsonArray}.
     *
     * @param jsonArray source array
     * @param num       maximum number of elements to copy
     * @return the truncated copy (empty if {@code num <= 0})
     */
    private static JSONArray getNumArray(JSONArray jsonArray, int num) {
        JSONArray resultArray = new JSONArray();
        for (int i = 0; i < num && i < jsonArray.size(); i++) {
            resultArray.add(jsonArray.getJSONObject(i));
        }
        return resultArray;
    }

    /**
     * Pretty-prints a compact JSON string using tab indentation. Characters inside
     * string literals (including commas and brackets) are copied verbatim. Escaped
     * quotes inside strings do not end the literal.
     *
     * @param jsonStr compact JSON text
     * @return the indented text, or "" for null/empty input
     */
    private static String formatJson(String jsonStr) {
        if (null == jsonStr || jsonStr.isEmpty()) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        int level = 0;
        boolean inString = false;
        boolean escaped = false;
        for (int i = 0; i < jsonStr.length(); i++) {
            char current = jsonStr.charAt(i);
            if (inString) {
                sb.append(current);
                if (escaped) {
                    escaped = false;
                } else if (current == '\\') {
                    escaped = true;
                } else if (current == '"') {
                    inString = false;
                }
                continue;
            }
            switch (current) {
                case '"':
                    inString = true;
                    sb.append(current);
                    break;
                case '{':
                case '[':
                    sb.append(current);
                    sb.append('\n');
                    level++;
                    addIndentBlank(sb, level);
                    break;
                case '}':
                case ']':
                    sb.append('\n');
                    level--;
                    addIndentBlank(sb, level);
                    sb.append(current);
                    break;
                case ',':
                    sb.append(current);
                    sb.append('\n');
                    addIndentBlank(sb, level);
                    break;
                default:
                    sb.append(current);
            }
        }
        return sb.toString();
    }

    /**
     * Appends {@code level} tab characters to {@code sb} (nothing if level <= 0).
     *
     * @param sb    target builder
     * @param level indentation depth
     */
    private static void addIndentBlank(StringBuilder sb, int level) {
        for (int i = 0; i < level; i++) {
            sb.append('\t');
        }
    }

}
